{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "Gated Linear Units and Variants",
   "provenance": [],
   "collapsed_sections": [],
   "toc_visible": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AYV_dMVDxyc2"
   },
   "source": [
    "[![Github](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/glu_variants/simple.ipynb)\n",
    "\n",
    "## Gated Linear Units and Variants\n",
    "\n",
    "This notebook trains a simple [transformer](https://nn.labml.ai/transformers/) model for auto-regression.\n",
    "It tries different variants of the [position-wise feedforward network](https://nn.labml.ai/transformers/feed_forward.html).\n",
    "\n",
    "The annotated trainer code is at [`simple.py`](https://nn.labml.ai/transformers/glu_variants/simple.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AahG_i2y5tY9"
   },
   "source": [
    "Install the `labml-nn` package"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "ZCzmCrAIVg0L",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "2de76edb-9911-496d-9f8c-281dad6f5680"
   },
   "source": [
    "!pip install labml-nn"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SE2VUQ6L5zxI"
   },
   "source": [
    "Imports"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "0hJXx_g0wS2C"
   },
   "source": [
    "import dataclasses\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from labml import experiment\n",
    "from labml_nn.transformers.glu_variants.simple import Configs, Trainer"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lpggo0wM6qb-"
   },
   "source": [
    "Create an experiment"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "bFcr9k-l4cAg"
   },
   "source": [
    "experiment.create(name=\"glu_variants\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-OnHLi626tJt"
   },
   "source": [
    "Initialize configurations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Piz0c5f44hRo"
   },
   "source": [
    "conf = Configs()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wwMzCqpD6vkL"
   },
   "source": [
    "Register the configurations with the experiment, passing `conf` as a dictionary via `dataclasses.asdict`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 17
    },
    "id": "e6hmQhTw4nks",
    "outputId": "77eca625-7205-49ea-f275-23f2710c4d84"
   },
   "source": [
    "experiment.configs(dataclasses.asdict(conf))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DHyNvXfnzeWQ"
   },
   "source": [
    "Create [`Trainer`](https://nn.labml.ai/transformers/glu_variants/simple.html)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "59ZeTv5SzcVe"
   },
   "source": [
    "trainer = Trainer(conf)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EvI7MtgJ61w5"
   },
   "source": [
    "Register the PyTorch model with the experiment so it can be saved and loaded"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "GDlt7dp-5ALt"
   },
   "source": [
    "experiment.add_pytorch_models({'model': trainer.model})"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KJZRf8527GxL"
   },
   "source": [
    "Start the experiment and run the training loop."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 255
    },
    "id": "aIAWo7Fw5DR8",
    "outputId": "18b8b334-f9e7-458b-f900-5828b4f9a5c8"
   },
   "source": [
    "with experiment.start():\n",
    "    trainer.train()"
   ],
   "outputs": [],
   "execution_count": null
  }
 ]
}
